import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from fsrl.utils import DummyLogger, WandbLogger
from tqdm.auto import trange  # noqa

from osrl.common.net import mlp
from torch.distributions.normal import Normal


class RTG_model(nn.Module):
    """
    CTG-conditioned RTG model.

    Predicts a Gaussian over the reward-to-go (RTG) from the cost-to-go (CTG),
    optionally also conditioned on the current state and/or a task prompt.
    Each enabled input gets its own linear embedding; the embeddings are
    concatenated along the last dimension, passed through an MLP, and two
    linear heads produce the mean and log-std of the prediction.
    """

    def __init__(self,
                 state_dim: int,
                 prompt_dim: int = 32,
                 cost_embedding_dim: int = 128,
                 state_embedding_dim: int = 128,
                 prompt_embedding_dim: int = 128,
                 r_hidden_sizes: "list[int] | None" = None,
                 use_state: bool = False,
                 use_prompt: bool = False):
        """
        Args:
            state_dim: input size of the state embedding (last dim of ``states``).
            prompt_dim: input size of the prompt embedding (last dim of ``prompts``).
            cost_embedding_dim: output width of the CTG embedding.
            state_embedding_dim: output width of the state embedding
                (only used when ``use_state``).
            prompt_embedding_dim: output width of the prompt embedding
                (only used when ``use_prompt``).
            r_hidden_sizes: hidden-layer widths handed to ``mlp``; defaults to
                ``[128, 128]``. Must be non-empty because its last entry sizes
                the output heads.
            use_state: also condition on ``states``.
            use_prompt: also condition on ``prompts``.

        Raises:
            ValueError: if ``r_hidden_sizes`` is empty.
        """
        super().__init__()
        # ``None`` instead of a list-literal default so instances never share
        # one default list object.
        if r_hidden_sizes is None:
            r_hidden_sizes = [128, 128]
        if len(r_hidden_sizes) == 0:
            raise ValueError("r_hidden_sizes must contain at least one layer width")

        self.state_dim = state_dim
        self.prompt_dim = prompt_dim
        self.use_state = use_state
        self.use_prompt = use_prompt
        self.r_hidden_sizes = list(r_hidden_sizes)

        # Submodules are registered in the same order as before so that the
        # ``parameters()`` order (and thus any saved optimizer state) is unchanged.
        self.mlp_input_dim = cost_embedding_dim
        self.cost_embedding = nn.Linear(1, cost_embedding_dim)
        if self.use_state:
            self.state_embedding = nn.Linear(state_dim, state_embedding_dim)
            self.mlp_input_dim += state_embedding_dim
        if self.use_prompt:
            self.prompt_embedding = nn.Linear(prompt_dim, prompt_embedding_dim)
            self.mlp_input_dim += prompt_embedding_dim

        self.mlp_dims = [self.mlp_input_dim] + self.r_hidden_sizes
        self.mlp_net = mlp(self.mlp_dims, nn.ReLU)
        self.mu_layer = nn.Linear(self.r_hidden_sizes[-1], 1)
        self.logstd_layer = nn.Linear(self.r_hidden_sizes[-1], 1)

    def forward(self, costs, states=None, prompts=None):
        """
        Predict the RTG distribution.

        Args:
            costs: CTG values; a trailing feature dim of size 1 is added before
                embedding, so presumably shape ``(...,)`` — TODO confirm with callers.
            states: required when ``use_state``; last dim must be ``state_dim``.
            prompts: required when ``use_prompt``; last dim must be ``prompt_dim``.
                A 1-D prompt is treated as a single prompt and is broadcast
                across the batch of ``costs``.

        Returns:
            ``(mu, rtg_dis, std)``: the predicted mean (trailing dim 1), the
            ``Normal(mu, std)`` distribution, and ``std`` (computed from a
            log-std clamped to ``[-20, 2]``).

        Raises:
            ValueError: if a conditioning input enabled at construction is ``None``.
        """
        if prompts is not None and prompts.dim() == 1:
            prompts = prompts.unsqueeze(0)

        embeddings = [self.cost_embedding(costs.unsqueeze(-1))]
        if self.use_state:
            if states is None:
                raise ValueError("RTG_model was built with use_state=True but states is None")
            embeddings.append(self.state_embedding(states))
        if self.use_prompt:
            if prompts is None:
                raise ValueError("RTG_model was built with use_prompt=True but prompts is None")
            embeddings.append(self.prompt_embedding(prompts))

        # torch.cat does not broadcast: expand every embedding to the common
        # leading (batch) shape so e.g. one shared prompt can pair with a batch
        # of costs. When all leading shapes already agree this is a no-op view.
        lead_shape = torch.broadcast_shapes(*(e.shape[:-1] for e in embeddings))
        embeddings = [e.expand(*lead_shape, e.shape[-1]) for e in embeddings]

        hidden = self.mlp_net(torch.cat(embeddings, dim=-1))
        mu = self.mu_layer(hidden)
        # Clamp log-std for numerical stability before exponentiating.
        log_std = torch.clamp(self.logstd_layer(hidden), -20, 2)
        std = torch.exp(log_std)
        rtg_dis = Normal(mu, std)
        return mu, rtg_dis, std

class MTRTG(nn.Module):
    """
    Multi-task RTG wrapper.

    Holds one shared RTG model plus one state encoder per task. When the
    shared model conditions on states, incoming states are first mapped
    through ``state_ae_ls[task_id].encode`` before being passed on.
    """

    def __init__(self,
                 RTG_model,
                 state_encoder_ls):
        """
        Args:
            RTG_model: the shared model; must expose ``use_state`` and
                ``use_prompt`` and return a 3-tuple when called.
            state_encoder_ls: iterable of per-task modules with an ``encode`` method.
        """
        super().__init__()
        self.rtg_model = RTG_model
        self.state_ae_ls = nn.ModuleList(state_encoder_ls)
        # Mirror the wrapped model's conditioning flags.
        self.use_state, self.use_prompt = (
            self.rtg_model.use_state,
            self.rtg_model.use_prompt,
        )

    def forward(self, costs, states=None, prompts=None, task_id=0):
        """Encode ``states`` with the task's encoder (if used) and run the shared model.

        Returns the wrapped model's ``(mu, rtg_dis, std)`` triple.
        """
        encode_needed = self.use_state and states is not None
        model_states = (
            self.state_ae_ls[task_id].encode(states) if encode_needed else states
        )
        mean, dist, sigma = self.rtg_model(costs, model_states, prompts)
        return mean, dist, sigma

class RTGTrainer:
    """
    CTG-conditioned RTG trainer.

    Fits a model that returns ``(mu, rtg_dis, std)`` (as ``RTG_model`` does)
    to reward-to-go targets with Adam, using either MSE on ``mu`` or the
    negative log-likelihood of ``rtg_dis``.
    """

    def __init__(
            self,
            model,
            logger: "WandbLogger | None" = None,
            # training params
            learning_rate: float = 1e-4,
            device="cpu",
            logprob_loss: bool = False) -> None:
        """
        Args:
            model: module called as ``model(ctg, states=..., prompts=...)``
                returning ``(mu, rtg_dis, std)``.
            logger: object with a ``store(tab=..., **metrics)`` method; a new
                ``DummyLogger`` is created per trainer when omitted.
            learning_rate: Adam learning rate.
            device: stored as ``self.device``; not used by this class's methods.
            logprob_loss: train on the NLL of ``rtg_dis`` instead of MSE on ``mu``.
        """
        self.model = model
        # Build the default per instance instead of sharing one logger object
        # created once when the module was imported.
        self.logger = DummyLogger() if logger is None else logger
        self.device = device
        self.logprob_loss = logprob_loss

        self.optim = torch.optim.Adam(
            self.model.parameters(),
            lr=learning_rate
        )

    @staticmethod
    def _align_target(rtg, preds):
        """Return ``rtg`` shaped to line up element-wise with ``preds``.

        The prediction carries a trailing singleton dim. A target shaped
        ``preds.shape[:-1]`` would otherwise broadcast against it into an
        outer-product shape inside ``mse_loss`` / ``log_prob``, silently
        producing a wrong loss.
        """
        if (rtg.shape != preds.shape and preds.shape[-1:] == (1,)
                and rtg.shape == preds.shape[:-1]):
            return rtg.unsqueeze(-1)
        return rtg

    def train_one_step(self, rtg, ctg, states=None, prompts=None):
        """Run one optimisation step and return the training loss as a float."""
        rtg_preds, rtg_dis, _ = self.model(
            ctg,
            states=states,
            prompts=prompts
        )
        rtg = self._align_target(rtg, rtg_preds)
        if self.logprob_loss:
            rtg_loss = -rtg_dis.log_prob(rtg).mean()
        else:
            rtg_loss = F.mse_loss(rtg_preds, rtg, reduction="mean")

        self.optim.zero_grad()
        rtg_loss.backward()
        # nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10, norm_type=2)
        self.optim.step()

        loss_value = rtg_loss.item()
        self.logger.store(
            tab="train",
            rtg_loss=loss_value,
        )
        return loss_value

    def eval_one_step(self, rtg, ctg, states=None, prompts=None):
        """Return the MSE of the predicted mean against ``rtg`` without training.

        Runs under ``torch.no_grad()`` in eval mode, then restores the
        model's previous train/eval mode even if the forward pass raises.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                rtg_preds, _, _ = self.model(
                    ctg,
                    states=states,
                    prompts=prompts
                )
                rtg = self._align_target(rtg, rtg_preds)
                loss_value = F.mse_loss(rtg_preds, rtg, reduction="mean").item()
        finally:
            self.model.train(was_training)

        self.logger.store(
            tab="eval",
            rtg_loss=loss_value,
        )
        return loss_value


class MTRTGTrainer:
    """
    Multi-task CTG-conditioned RTG trainer.

    Like ``RTGTrainer`` but routes each batch through a task-specific path of
    the model (``task_id``) and logs under a per-task tab (``task_name``).
    """

    def __init__(
            self,
            model,
            logger: "WandbLogger | None" = None,
            # training params
            learning_rate: float = 1e-4,
            device="cpu",
            logprob_loss: bool = False) -> None:
        """
        Args:
            model: module called as ``model(ctg, states=..., prompts=...,
                task_id=...)`` returning ``(mu, rtg_dis, std)``.
            logger: object with a ``store(tab=..., **metrics)`` method; a new
                ``DummyLogger`` is created per trainer when omitted.
            learning_rate: Adam learning rate.
            device: stored as ``self.device``; not used by this class's methods.
            logprob_loss: train on the NLL of ``rtg_dis`` instead of MSE.
        """
        self.model = model
        # Build the default per instance instead of sharing one logger object
        # created once when the module was imported.
        self.logger = DummyLogger() if logger is None else logger
        self.device = device
        self.logprob_loss = logprob_loss

        self.optim = torch.optim.Adam(
            self.model.parameters(),
            lr=learning_rate
        )

    @staticmethod
    def _align_target(rtg, preds):
        """Return ``rtg`` shaped to line up element-wise with ``preds``.

        The prediction carries a trailing singleton dim. A target shaped
        ``preds.shape[:-1]`` would otherwise broadcast against it into an
        outer-product shape inside ``mse_loss`` / ``log_prob``, silently
        producing a wrong loss.
        """
        if (rtg.shape != preds.shape and preds.shape[-1:] == (1,)
                and rtg.shape == preds.shape[:-1]):
            return rtg.unsqueeze(-1)
        return rtg

    def train_one_step(self, rtg, ctg, task_name, task_id, states=None, prompts=None):
        """Run one optimisation step for task ``task_id`` and return the loss as a float."""
        mu, rtg_dis, _ = self.model(
            ctg,
            states=states,
            prompts=prompts,
            task_id=task_id
        )
        rtg = self._align_target(rtg, mu)
        if self.logprob_loss:
            rtg_loss = -rtg_dis.log_prob(rtg).mean()
        else:
            # NOTE(review): unlike RTGTrainer this MSE uses a reparameterised
            # sample, not ``mu``; in expectation it adds std**2 to the squared
            # error, so it also drives the predicted std down. Confirm intended.
            rtg_preds = rtg_dis.rsample()
            rtg_loss = F.mse_loss(rtg_preds, rtg, reduction="mean")

        self.optim.zero_grad()
        rtg_loss.backward()
        # nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10, norm_type=2)
        self.optim.step()

        loss_value = rtg_loss.item()
        self.logger.store(
            tab=f"train/{task_name}",
            rtg_loss=loss_value,
        )
        return loss_value

    def eval_one_step(self, rtg, ctg, task_name, task_id, states=None, prompts=None):
        """Return the MSE of the predicted mean against ``rtg`` for task ``task_id``.

        Runs under ``torch.no_grad()`` in eval mode, then restores the
        model's previous train/eval mode even if the forward pass raises.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                mu, _, _ = self.model(
                    ctg,
                    states=states,
                    prompts=prompts,
                    task_id=task_id
                )
                rtg = self._align_target(rtg, mu)
                loss_value = F.mse_loss(mu, rtg, reduction="mean").item()
        finally:
            self.model.train(was_training)

        self.logger.store(
            tab=f"eval/{task_name}",
            rtg_loss=loss_value,
        )
        return loss_value
